from network import SketchANet
import torch
import numpy as np
import os
from hyper_params import hp
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
from PIL import Image
import random


# Seed every random number generator this script can touch so that repeated
# evaluation runs are comparable.
SEED = 123
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Restrict cuDNN to deterministic kernels. Autotuning (benchmark mode) is
# switched off because it may select a different algorithm from one run to
# the next, which works against the reproducibility set up above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Build the classifier and restore the trained weights.
# The checkpoint is mapped to the CPU first so that loading does not depend
# on the GPU index it was saved from; the model is moved to the GPU later.
model = SketchANet()
model.load_state_dict(
    torch.load('./model_save/model_epoch_15000.pth', map_location='cpu')
)

# Preprocessing applied to every PIL image: resize, then convert to a tensor.
trans = transforms.Compose([transforms.Resize(225), transforms.ToTensor()])
# Alternative without the resize step:
#trans = transforms.ToTensor()


class Sketch_dataset(Dataset):
    """In-memory dataset of binarised, colour-inverted sketch images.

    For each entry of ``hp.category`` (with everything after the first ``.``
    dropped) the directory ``<root>/<category>`` is listed and every file in
    it is loaded eagerly. The label of a sample is the index of its category
    in ``hp.category``.

    Args:
        root: Directory that contains one sub-directory per category.

    Attributes are plain Python lists: ``imgs`` holds the preprocessed
    tensors and ``labels`` the matching integer class indices.
    """

    def __init__(self, root):
        super(Sketch_dataset, self).__init__()
        self.root = root
        self.imgs = []
        self.labels = []

        for i, cat in enumerate(hp.category):
            cat = cat.split('.')[0]
            path = os.path.join(self.root, cat)

            print(f"{cat} loading, label {i}")
            # Sorted so the sample order does not depend on the file system.
            for name in sorted(os.listdir(path)):
                # The context manager closes the file handle as soon as the
                # pixel data has been copied by convert().
                with Image.open(os.path.join(path, name)) as raw:
                    img = raw.convert('1').resize((225, 225))

                # Swap foreground and background. A uint8 mask is used
                # because multiplying an int8 array by 255 overflows under
                # NumPy 2 promotion rules.
                inverted = np.logical_not(np.array(img)).astype(np.uint8) * 255
                img = Image.fromarray(inverted).convert('1')

                self.labels.append(i)
                # NOTE(review): presumably ToTensor yields 0/1 here, so the
                # stored tensor is 0/255 - confirm against the training code.
                self.imgs.append(trans(img) * 255)
        print(f"{len(self.labels)} imgs loaded")

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.imgs[index], self.labels[index]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.labels)


# Read every evaluation sketch into memory, then serve it in batches of 200.
test_dataset = Sketch_dataset(root='../Lmser2rwp/results/0.5/sketch/')
#test_dataset = Sketch_dataset('../SketchLattice/results/models_32_150/visualize/0.5/sketch/')
dataloader = DataLoader(dataset=test_dataset, batch_size=200)

# Put the network on the GPU and switch it to evaluation mode.
model = model.cuda()
model.eval()

if __name__ == '__main__':
    # Evaluate top-1 accuracy of the loaded model over the whole test set.
    total = len(test_dataset)
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise RuntimeError("test dataset is empty; cannot compute accuracy")

    correct = 0
    # Gradients are not needed for evaluation; disabling them avoids building
    # an autograd graph for every batch.
    with torch.no_grad():
        for imgs, labels in dataloader:
            Y = labels.cuda()
            X = imgs.to(torch.float32).cuda()
            # Reshape to (N, 1, S, S) with S = hp.graph_picture_size.
            X = X.view(-1, 1, hp.graph_picture_size, hp.graph_picture_size)
            output = model(X)
            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(output, 1)
            correct += (predicted == Y).sum().item()
            # Running total of correct predictions so far.
            print(f'correct:{correct}')

    acc = correct / total
    print(f"accuracy{acc}")

